[JAX] Handle meshs set with jax.set_mesh#2532
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci L1 L2 jax |
Greptile SummaryFixed a bug where TransformerEngine/JAX couldn't query mesh information from meshes set via Key Changes:
Impact:
Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant JAX
participant TE as TransformerEngine
participant PXLA as PXLA Thread Resources
participant GetAbstractMesh as jax.sharding.get_abstract_mesh()
Note over User,GetAbstractMesh: Scenario 1: Mesh set via 'with mesh:' context
User->>JAX: with mesh:
JAX->>PXLA: Set physical_mesh
User->>TE: Call TE function (e.g., is_mesh_available())
TE->>TE: _get_mesh()
TE->>PXLA: Check physical_mesh
PXLA-->>TE: Returns mesh (not empty)
TE-->>User: Returns mesh
Note over User,GetAbstractMesh: Scenario 2: Mesh set via jax.set_mesh()
User->>JAX: jax.set_mesh(mesh)
JAX->>GetAbstractMesh: Store mesh in abstract context
Note over PXLA: physical_mesh remains None or empty
User->>TE: Call TE function (e.g., is_mesh_available())
TE->>TE: _get_mesh()
TE->>PXLA: Check physical_mesh
PXLA-->>TE: Returns None or empty mesh
TE->>GetAbstractMesh: Fallback to get_abstract_mesh()
GetAbstractMesh-->>TE: Returns mesh from jax.set_mesh()
TE-->>User: Returns mesh
|
There was a problem hiding this comment.
Additional Comments (1)
-
transformer_engine/jax/sharding.py, line 40-46 (link)style: check if the fallback to
get_abstract_mesh()should happen whenphysical_meshis explicitly set to an empty (but non-None) mesh. Currently, if someone useswith empty_mesh:, it will fallback toget_abstract_mesh()which might return a different mesh set viajax.set_mesh(). Consider whether the condition should beif mesh is None:instead ofif mesh is not None and not mesh.empty:. Ifphysical_meshis explicitly set to an empty mesh object (not None), should the code still fallback toget_abstract_mesh()?
1 file reviewed, 1 comment
Description
Fixes issue that we cannot query jax Mesh info from Meshs set via
jax.set_meshType of change
Changes
get_abstract_mesh()in addition to pxla thread resources when looking for mesh contextChecklist: